In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
from functools import partial
import numpy as np
import torch
import e3nn
from spherical import plot_data_on_grid, SphericalTensor, projection
import e3nn.o3 as o3
import e3nn.rs as rs

import plotly
import plotly.graph_objects as go

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import math

Example radial models (fixed radial basis functions) used for plotting

In [3]:
def ConstantRadialModel():
    """Build a trivial radial model that ignores the distances it is given.

    Returns:
        A callable mapping a tensor ``r`` of any shape to a tensor of ones with
        shape ``r.shape + (1,)`` (a single constant "basis function"). The output
        uses torch's default dtype and lives on the same device as ``r``.
    """
    def radial_function(r):
        # Allocate on r's device so the result can be combined with GPU inputs.
        return torch.ones(r.shape + (1,), device=r.device)
    return radial_function

def FixedCosineRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a fixed (non-learnable) radial basis of cosine-squared bumps.

    ``number_of_basis`` centers are spaced evenly over ``[min_radius, max_radius]``
    (both ends included). With ``step`` the distance between neighbouring centers
    and ``d = (r - center_k) / step``, basis function ``k`` equals
    ``cos(pi/2 * d)**2`` for ``|d| < 1`` and (up to floating-point rounding) 0
    otherwise. Neighbouring bumps overlap, so between the first and last center
    the basis functions sum to 1.

    Args:
        max_radius: Position of the last center.
        number_of_basis: Number of basis functions; must be at least 2 so that
            the center spacing is defined.
        min_radius: Position of the first center.

    Returns:
        A callable mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.

    Raises:
        ValueError: If ``number_of_basis < 2``.
    """
    if number_of_basis < 2:
        raise ValueError(
            "FixedCosineRadialModel needs number_of_basis >= 2 to define a "
            f"center spacing, got {number_of_basis}"
        )
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    step = radii[1] - radii[0]

    def radial_function(r):
        # Keep the centers on r's device so GPU inputs work.
        centers = radii.to(r.device)
        # Signed distance to every center in units of the spacing; the new
        # trailing axis broadcasts against the 1-D ``centers``.
        d = (r.unsqueeze(-1) - centers) / step
        # Clamping to [-1, 1] makes cos^2 vanish outside each bump's support.
        return torch.cos(d.clamp(-1., 1.) * (math.pi / 2)) ** 2

    return radial_function

def FixedGaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a fixed (non-learnable) radial basis of Gaussians.

    ``number_of_basis`` centers are spaced evenly over ``[min_radius, max_radius]``
    (both ends included). Basis function ``k`` is
    ``exp(-gamma * (r - center_k)**2)`` with
    ``gamma = number_of_basis / (max_radius - min_radius)``.

    Args:
        max_radius: Position of the last center.
        number_of_basis: Number of basis functions.
        min_radius: Position of the first center.

    Returns:
        A callable mapping a tensor ``r`` of any shape to a tensor of shape
        ``r.shape + (number_of_basis,)``.

    Raises:
        ZeroDivisionError: With plain-float arguments, if ``number_of_basis == 0``
            or ``max_radius == min_radius`` (while computing ``gamma``).
    """
    spacing = (max_radius - min_radius) / number_of_basis
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    # NOTE(review): gamma is 1/spacing rather than 1/spacing**2, so the Gaussian
    # width does not scale linearly with the center spacing; confirm intended.
    gamma = 1. / spacing

    def radial_function(r):
        # Keep the centers on r's device so GPU inputs work; the trailing axis
        # added by unsqueeze broadcasts against the 1-D centers.
        centers = radii.to(r.device)
        return torch.exp(-gamma * (r.unsqueeze(-1) - centers) ** 2)

    return radial_function

Set up the vertex coordinates of a regular tetrahedron and choose `lmax`, the maximum spherical-harmonic degree.

In [4]:
# A regular tetrahedron: four alternating vertices of the unit cube. Every pair
# of them is joined by a face diagonal, so all six edges have length sqrt(2).
tetra_coords = torch.tensor([
    [0., 0., 0.],
    [1., 1., 0.],
    [1., 0., 1.],
    [0., 1., 1.],
])
# Translate the vertices so that their centroid sits at the origin.
tetra_coords = tetra_coords - tetra_coords.mean(dim=0)

# Maximum spherical-harmonic degree passed to the SphericalTensor constructors.
lmax = 3

Project the tetrahedron onto spherical harmonics combined with radial basis functions, then plot the result as a 3D volume.

In [5]:
# Radial basis: n_radial cosine-squared bumps whose centers are spread evenly over [0, max_radius].
n_radial = 3
max_radius = 2.
sphten = SphericalTensor.from_geometry_with_radial(tetra_coords, FixedCosineRadialModel(max_radius, n_radial), lmax)
# NOTE(review): the meaning of the 5. argument (e.g. plotting box size or grid extent) is defined in
# `spherical`; confirm and consider giving it a named constant.
# Presumably x holds sample coordinates (the next cell indexes x[:, 0..2]) and f the values at those points.
x, f = sphten.plot_with_radial(5.)
In [6]:
# Symmetric iso range: half of the largest absolute value in `f`.
plot_max = 0.5 * float(f.abs().max())
trace = go.Volume(
    x=x[:, 0],
    y=x[:, 1],
    z=x[:, 2],
    value=f,
    isomin=-plot_max,
    isomax=plot_max,
    opacity=0.3,        # keep low so the inner isosurfaces remain visible
    surface_count=10,   # more isosurfaces give a smoother volume rendering
    colorscale='RdBu',  # diverging scale, so positive and negative values differ in colour
)
go.Figure(data=[trace])

Project the tetrahedron onto spherical harmonics only (no radial functions), and plot the result with the signal magnitude used as the radius.

In [7]:
# Angular-only projection of the same geometry (no radial basis functions).
sphten = SphericalTensor.from_geometry(tetra_coords, lmax)
# NOTE(review): the semantics of `n` (presumably sampling resolution) and `relu` live in `spherical`; confirm.
trace = sphten.plot(relu=True, n=50)
go.Figure(data=trace)
In [ ]: